from .base_reasoner import BaseReasoner, ReasoningNode
import asyncio
import argparse
import json
import os
import re
import time
import traceback
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any, Union
import random
import openai
from collections import defaultdict
from datetime import datetime

class StrategyQAReasoner(BaseReasoner):
    def __init__(self):
        """Initialise the reasoner under the name "StrategyQA" and set its dataset path."""
        # The base-class initialiser receives the reasoner name; it presumably
        # also creates self.config, which is mutated below — TODO confirm.
        super().__init__("StrategyQA")
        # Relative path to the JSON dataset file read by load_problems().
        self.config.dataset_path = "datasets/StrategyQA.json"
    
    async def load_problems(self, start_idx: int, end_idx: int) -> List[Dict]:
        """Load StrategyQA problems from dataset"""
        try:
            with open(self.config.dataset_path, "r", encoding="utf-8") as f:
                data = json.load(f)
                return data[start_idx:end_idx]
        except Exception as e:
            print(f"Error loading dataset: {str(e)}")
            return []
    
    async def execute_workflow(self, problem: Dict[str, Any]) -> Dict[str, Any]:
        """Execute full reasoning workflow for a StrategyQA problem"""
        try:
            question = problem["question"]
            facts = problem.get("facts", [])
            
            # Step 1: Create root node
            root = self._create_node(
                question=question,
                facts=facts,
                conditions={},
                path=[],
                method={"description": "Original problem"}
            )
            self._log_step("step1", root.node_id, {"question": question})
            
            # Step 2: Extract conditions
            conditions = await self._condition_extraction(question, facts)
            root.conditions = conditions
            self._log_step("step2", root.node_id, {"conditions": conditions})
            
            # Step 3: Explore solution methods
            methods = await self._tree_explorer(question, facts)
            self._log_step("step3", root.node_id, {"methods": methods})
            
            # Step 4: Create method nodes
            method_nodes = []
            for method in methods[:self.config.beam_width]:
                node = self._create_node(
                    path=[root.node_id],
                    question=question,
                    facts=facts,
                    method=method,
                    conditions=root.conditions,
                    score=method.get("score", 0),
                    parent_id=root.node_id
                )
                root.children.append(node.node_id)
                method_nodes.append(node)
                self._log_step("step4", node.node_id, {"method": method})
            
            # Step 5: Check classification for best method
            best_method_node = max(method_nodes, key=lambda x: x.score)
            classification = await self._adaptive_domain_decomposition(
                best_method_node.method["description"],
                question,
                facts
            )
            self._log_step("step5", best_method_node.node_id, {"classification": classification})
            
            if classification["need_classify"]:
                # Step 6: Create classification nodes
                for case in classification["cases"]:
                    combined_conditions = {
                        "explicit": best_method_node.conditions.get("explicit", []).copy(),
                        "implicit": best_method_node.conditions.get("implicit", []).copy()
                    }
                    
                    for k, v in case["conditions"].items():
                        if k in combined_conditions:
                            combined_conditions[k].append(v)
                        else:
                            combined_conditions.setdefault("implicit", []).append(f"{k}: {v}")
                    
                    node = self._create_node(
                        path=best_method_node.path + [best_method_node.node_id],
                        question=question,
                        facts=facts,
                        method=best_method_node.method,
                        conditions=combined_conditions,
                        score=best_method_node.score,
                        parent_id=best_method_node.node_id
                    )
                    best_method_node.children.append(node.node_id)
                    self.temp_list.append(node.node_id)
                    self._log_step("step6", node.node_id, {"case": case})
            else:
                self.temp_list.append(best_method_node.node_id)
            
            # Step 7: Solve nodes
            solutions = []
            for node_id in self.temp_list:
                solution = await self._resolution(node_id)
                if solution:
                    solutions.append(solution)
                    self._log_step("step7", node_id, {"solution": solution})
            
            # Step 8: Aggregate answers
            final_answer = await self._aggregation(solutions)
            self._log_step("step8", "system", {"final_answer": final_answer})
            
            return {
                "status": "success",
                "final_answer": final_answer,
                "nodes": self.nodes,
                "logs": self.logs,
                "token_usage": self.llm.token_counts
            }
            
        except Exception as e:
            traceback.print_exc()
            return {
                "status": "error",
                "message": str(e),
                "logs": self.logs
            }
    
    async def _condition_extraction(self, question: str, facts: List[str]) -> Dict[str, Any]:
        """Extract conditions from problem and facts"""
        facts_text = "\n".join([f"- {fact}" for fact in facts])
        
        prompt = f"""You are a top expert in analytical reasoning and factual analysis.  
You are precise, rational, and skeptical.  
You always examine each statement carefully, identify key facts, and evaluate logical validity step by step.  
You avoid unwarranted assumptions, think in terms of logical consequences, and reason carefully based on available facts.  
You aim to reach conclusions based only on evidence and logic.  
You THINK SLOWLY, CAREFULLY, AND LOGICALLY.
Analyze this StrategyQA question and extract key conditions:

Question: {question}
Facts:
{facts_text}

Identify:
1. Explicit conditions (directly stated in facts)
2. Implicit conditions (logical implications from facts)
3. Key terms and their relationships
4. Any conditional statements or dependencies

Output JSON format:
{{
    "explicit": ["list", "of", "conditions"],
    "implicit": ["list", "of", "conditions"],
    "key_terms": ["term1", "term2"],
    "notes": "Analysis summary"
}}"""
        
        for attempt in range(self.config.max_retries):
            try:
                response = await self.llm.generate(prompt, response_format="json_object")
                return json.loads(response)
            except:
                continue
        
        return {
            "explicit": [],
            "implicit": [],
            "key_terms": [],
            "notes": "Failed to extract conditions"
        }
    
    async def _tree_explorer(self, question: str, facts: List[str]) -> List[Dict]:
        """Step 3: Explore diverse solution methods for StrategyQA"""
        facts_text = "\n".join([f"- {fact}" for fact in facts])
        
        prompt = f"""You are a top expert in analytical reasoning and factual analysis.  
You are precise, rational, and skeptical.  
You always examine each statement carefully, identify key facts, and evaluate logical validity step by step.  
You avoid unwarranted assumptions, think in terms of logical consequences, and reason carefully based on available facts.  
You aim to reach conclusions based only on evidence and logic.  
You THINK SLOWLY, CAREFULLY, AND LOGICALLY.
Generate 3 distinct solution approaches for this question:

Question: {question}
Facts:
{facts_text}

For each approach, provide:
- Clear description of the reasoning strategy
- Key steps to implement the approach
- Confidence score (0-100) based on:
  * Logical soundness
  * Coverage of facts
  * Appropriate use of deductive/inductive reasoning
  * Clarity of reasoning steps

Output JSON format:
{{
    "methods": [
        {{
            "description": "Approach description",
            "steps": ["step1", "step2"],
            "score": 0-100,
            "score_reason": "Scoring justification"
        }}
    ]
}}"""
        
        for attempt in range(self.config.max_retries):
            try:
                response = await self.llm.generate(prompt, response_format="json_object")
                response = response.strip()
                
                if response.startswith("```json"):
                    response = response[7:-3].strip()
                elif response.startswith("```"):
                    response = response[3:-3].strip()
                
                data = json.loads(response)
                
                if not isinstance(data, dict) or "methods" not in data:
                    raise ValueError("Invalid structure: missing 'methods' key")
                    
                methods = data["methods"]
                if len(methods) < 2:
                    raise ValueError(f"Expected at least 2 methods, got {len(methods)}")
                    
                required_keys = {"description", "steps", "score", "score_reason"}
                for method in methods:
                    if not all(k in method for k in required_keys):
                        raise ValueError("Missing required keys in method")
                    if not isinstance(method["steps"], list):
                        raise ValueError("Steps must be a list")
                        
                return sorted(methods, key=lambda x: -x["score"])
                
            except (json.JSONDecodeError, ValueError, KeyError) as e:
                print(f"Attempt {attempt + 1} failed: {str(e)}")
                if attempt == self.config.max_retries - 1:
                    print(f"Final failed response: {response}")
                    return []
                continue
                
        return []
    
    async def _adaptive_domain_decomposition(self, method: str, question: str, facts: List[str]) -> Dict[str, Any]:
        """Step 5: Determine if classification needed for StrategyQA"""
        facts_text = "\n".join([f"- {fact}" for fact in facts])
        
        prompt = f"""You are a top expert in analytical reasoning and factual analysis.  
You are precise, rational, and skeptical.  
You always examine each statement carefully, identify key facts, and evaluate logical validity step by step.  
You avoid unwarranted assumptions, think in terms of logical consequences, and reason carefully based on available facts.  
You aim to reach conclusions based only on evidence and logic.  
You THINK SLOWLY, CAREFULLY, AND LOGICALLY.
Determine if this solution approach requires case classification:

Solution Approach: {method}
Question: {question}
Facts:
{facts_text}

Consider:
1. Does the question contain multiple scenarios or cases?
2. Are there conditional facts that create distinct possibilities?
3. Would different interpretations of facts lead to different answers?
4. Are there different ways to combine the facts that affect the answer?

If classification needed, provide:
- Comprehensive case descriptions
- Precise conditions for each case
- Expected outcomes for each case

Output JSON format:
{{
    "need_classify": true/false,
    "reason": "Classification rationale",
    "cases": [
        {{
            "description": "Case description",
            "conditions": {{"parameter": "value_range"}}
        }}
    ]
}}"""
        
        try:
            response = await self.llm.generate(prompt, response_format="json_object")
            data = json.loads(response)
            return data
        except:
            return {
                "need_classify": False,
                "reason": "Analysis failed",
                "cases": []
            }
    
    async def _resolution(self, node_id: str) -> Optional[Dict[str, Any]]:
        """Step 7: Solve individual reasoning node"""
        node = self.nodes[node_id]
        
        context = f"Question: {node.question}\nFacts:\n"
        context += "\n".join([f"- {fact}" for fact in node.facts])
        
        context += f"\nSolution Approach: {node.method['description']}\n"
        context += f"conditions: {json.dumps(node.conditions, indent=2)}\n"
        
        prompt = f"""You are a top expert in analytical reasoning and factual analysis.  
You are precise, rational, and skeptical.  
You always examine each statement carefully, identify key facts, and evaluate logical validity step by step.  
You avoid unwarranted assumptions, think in terms of logical consequences, and reason carefully based on available facts.  
You aim to reach conclusions based only on evidence and logic.  
You THINK SLOWLY, CAREFULLY, AND LOGICALLY.
Solve this question using the specified approach:

{context}

Reasoning Steps:
1. Strictly follow the provided approach: {node.method['description']}
2. Execute each step: {', '.join(node.method['steps'])}
3. Consider all conditions
4. Evaluate each fact systematically
5. Provide clear justification for your conclusion
6. Select the best answer (true or false)

Output Requirements:
- End your response with: "Final Answer: [true/false]"
- Use \boxed{{[true/false]}} to denote your answer
- Your answer must be either true or false"""
        
        response = await self.llm.generate(prompt)
        answer = self._extract_answer(response)
        
        if answer is not None:
            node.answer = answer
            node.state = "solved"
            return {
                "node_id": node_id,
                "response": response,
                "answer": answer
            }
        return None
    
    async def _aggregation(self, solutions: List[Dict[str, Any]]) -> bool:
        """Step 8: Aggregate answers from multiple nodes"""
        if not solutions:
            return False
        
        if len(solutions) == 1:
            return solutions[0]["answer"]
        
        answers = [s["answer"] for s in solutions]
        if len(set(answers)) == 1:
            return answers[0]
        
        solutions_text = ""
        for i, sol in enumerate(solutions):
            node = self.nodes[sol["node_id"]]
            solutions_text += f"\n\nSolution {i+1} (Node {sol['node_id']}):"
            solutions_text += f"\nApproach: {node.method['description']}"
            solutions_text += f"\nconditions: {json.dumps(node.conditions, indent=2)}"
            solutions_text += f"\nAnswer: {sol['answer']}"
            solutions_text += f"\nReasoning Excerpt:\n{sol['response'][:]}..."
        
        prompt = f"""You are a top expert in analytical reasoning and factual analysis.  
You are precise, rational, and skeptical.  
You always examine each statement carefully, identify key facts, and evaluate logical validity step by step.  
You avoid unwarranted assumptions, think in terms of logical consequences, and reason carefully based on available facts.  
You aim to reach conclusions based only on evidence and logic.  
You THINK SLOWLY, CAREFULLY, AND LOGICALLY.
Synthesize these approaches:

{solutions_text}

Instructions:
1. Analyze all solutions and their approaches
2. Identify the most reliable reasoning
3. Verify consistency with conditions
4. Select the best overall answer
5. Output format: \boxed{{[true/false]}}"""
        
        response = await self.llm.generate(prompt)
        answer = self._extract_answer(response)
        return answer if answer is not None else False
    
    def save_results(self, result: Dict[str, Any], problem: Dict[str, Any]) -> Dict[str, Any]:
        serialized_nodes = {}
        for node_id, node in self.nodes.items():
            serialized_nodes[node_id] = {
                "node_id": node.node_id,
                "question": node.question,
                "facts": node.facts,
                "method": node.method,
                "conditions": node.conditions,
                "answer": node.answer,
                "state": node.state,
                "score": node.score
            }
        
        selected_answer = result.get("final_answer", False)
        correct_answer = problem.get("answer", False)
        is_correct = self.verify_answer(problem, selected_answer)
        verification = {
            "is_correct": is_correct,
            "correct_answer": correct_answer,
            "given_answer": result.get("final_answer")
        }

        return {
            "problem": problem,
            "result": {
                "final_answer": selected_answer,
                "correct_answer": correct_answer,
                "is_correct": is_correct,
                "nodes": serialized_nodes,
                "token_usage": result.get("token_usage", [0, 0])
            },
            "verification": verification
        }
    def _extract_answer(self, text: str) -> Optional[bool]:
        """Extract boolean answer from response text"""
        true_pattern = r'(?i)\b(true|yes|correct|right)\b'
        false_pattern = r'(?i)\b(false|no|incorrect|wrong)\b'
        
        if re.search(true_pattern, text):
            return True
        elif re.search(false_pattern, text):
            return False
        
        return None

    def verify_answer(self, problem: Dict[str, Any], selected_answer: bool) -> bool:
        """Verify if selected answer matches correct answer"""
        correct_answer = problem.get("answer", None)
        return selected_answer == correct_answer